import yaml
import logging
# Route the script's ``logging.info(...)`` calls to stderr.
# The root logger itself must be lowered to INFO: it defaults to WARNING and
# filters INFO records before they reach any handler, so setting only the
# handler's level (as before) silently dropped every info message.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
logger.addHandler(handler)

import torch
from diffusers import DDIMScheduler
from datasets import DatasetDict

from main.wmdiffusion import WMDetectStableDiffusionPipeline
from main.utils import *
from main.dataset import *
from main.nf_flow_models import *

import argparse

# Reproducibility over speed: the cuDNN autotuner (benchmark=True) may pick
# different kernels run to run, so keep it off and force deterministic kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


def init(cfgs, device):
    """Build the watermark-detection Stable Diffusion pipeline.

    Loads a DDIM scheduler and the pipeline weights from ``cfgs['model_id']``,
    moves the pipeline to ``device`` and disables its progress bar.

    Args:
        cfgs: parsed config mapping; only the ``'model_id'`` key is read here.
        device: torch device the pipeline is moved to.

    Returns:
        The configured ``WMDetectStableDiffusionPipeline`` instance.
    """
    logging.info('===== Init Stable Diffusion Pipeline =====')
    model_id = cfgs['model_id']
    ddim = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
    pipeline = WMDetectStableDiffusionPipeline.from_pretrained(model_id, scheduler=ddim)
    pipeline = pipeline.to(device)
    pipeline.set_progress_bar_config(disable=True)
    return pipeline

def get_init_latent(img_tensor, pipe, text_embeddings, guidance_scale=1.0,
                    num_inference_steps=50):
    """Recover the initial latents of an image by DDIM inversion.

    The image is encoded with ``pipe.get_image_latents`` (``sample=False``),
    and those latents are then passed through ``pipe.forward_diffusion``.

    Args:
        img_tensor: image batch handed to ``pipe.get_image_latents``
            (assumed (B, C, H, W) — TODO confirm expected value range).
        pipe: pipeline exposing ``get_image_latents`` and ``forward_diffusion``.
        text_embeddings: conditioning embeddings used during inversion.
        guidance_scale: guidance scale forwarded to ``forward_diffusion``.
        num_inference_steps: number of inversion steps; defaults to 50, the
            value this function previously hard-coded.

    Returns:
        The result of ``pipe.forward_diffusion`` (the inverted latents).
    """
    img_latents = pipe.get_image_latents(img_tensor, sample=False)
    return pipe.forward_diffusion(
        latents=img_latents,
        text_embeddings=text_embeddings,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
    )

def main(args):
    """DDIM-invert every image of a dataset and save the dataset with latents.

    Loads the YAML config at ``args.cfg_path``, obtains the dataset through
    ``get_dataset``, builds the pipeline, inverts each image with empty-prompt
    embeddings, adds the result as a ``'latents'`` column and saves it to
    ``<args.dataset_path>/<args.dataset>_{train|test}_latents``.

    Args:
        args: parsed CLI namespace; this function reads ``cfg_path``,
            ``is_train``, ``batch_size``, ``dataset`` and ``dataset_path``
            (``get_dataset`` may read further fields).
    """
    logging.info('===== Load Config =====')
    with open(args.cfg_path, 'r') as cfg_file:
        cfgs = yaml.safe_load(cfg_file)
    logging.info(cfgs)

    device = torch.device('cuda')

    dataset = get_dataset(args, is_train=args.is_train)

    # Tensorize and resize every image to 512x512 so all latents share a shape.
    # NOTE(review): ToTensor yields values in [0, 1]; confirm that is the range
    # pipe.get_image_latents expects.
    to_model_input = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((512, 512)),
    ])

    pipe = init(cfgs, device)
    # NOTE(review): one empty-prompt embedding per ``batch_size`` entry, while
    # ``map`` below runs unbatched (one image per call) — confirm inversion
    # tolerates batch_size > 1.
    empty_text_embeddings = pipe.get_text_embedding([''] * args.batch_size)

    def add_tensors(example):
        # assumes example['image'] is a PIL image (it is .convert()-ed here)
        img = to_model_input(example['image'].convert('RGB'))
        img_batch = torch.stack([img]).to(device)  # (1, 3, 512, 512)
        latents = get_init_latent(img_batch, pipe, empty_text_embeddings)
        return {'latents': latents.detach().cpu()}

    torch.cuda.empty_cache()

    new_dataset = dataset.map(add_tensors, batched=False)

    split_name = 'train' if args.is_train else 'test'
    save_path = os.path.join(args.dataset_path, args.dataset + f'_{split_name}_latents')
    new_dataset.save_to_disk(save_path)


def save_joint_train_test(train_dataset_path, test_dataset_path, save_path):
    """Combine separately saved train/test datasets into one ``DatasetDict``.

    Args:
        train_dataset_path: directory loaded with ``load_from_disk``; stored
            under the ``'train'`` key.
        test_dataset_path: directory loaded with ``load_from_disk``; stored
            under the ``'test'`` key.
        save_path: directory the combined ``DatasetDict`` is saved to.
    """
    splits = {
        'train': load_from_disk(train_dataset_path),
        'test': load_from_disk(test_dataset_path),
    }
    DatasetDict(splits).save_to_disk(save_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='diffusion watermark')
    parser.add_argument('--dataset', default='diffusiondb',
                        choices=['coco', 'diffusiondb', 'wikiart'],
                        help='dataset to invert')
    parser.add_argument('--dataset_path', default='/localhome/data/datasets/watermarking/training',
                        help='root data directory; the latents dataset is saved under it')
    parser.add_argument('--seed', default=0, type=int,
                        help='torch random seed (CPU and all CUDA devices)')
    parser.add_argument('--cfg_path', default='./example/config/config.yaml',
                        help='YAML config file (must define model_id)')
    parser.add_argument('--batch_size', default=1, type=int,
                        help='number of empty-prompt text embeddings to build')
    parser.add_argument('--is_train', action='store_true',
                        help='use the training split instead of the test split')

    args = parser.parse_args()

    # torch.manual_seed already seeds every device (it calls
    # torch.cuda.manual_seed_all internally), so separate CUDA seeding is redundant.
    torch.manual_seed(args.seed)

    main(args)

